{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02113.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02113.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ckE-mWntOloZ",
        "colab_type": "text"
      },
      "source": [
        "## 10.7 文本情感分类：使用循环神经网络"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "53Dnm6h4O8Rb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import collections \n",
        "import os \n",
        "import random \n",
        "import tarfile \n",
        "import torch \n",
        "from torch import nn \n",
        "import torchtext.vocab as Vocab \n",
        "import torch.utils.data as Data \n",
        "import d2l \n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X-VJeUAEdddV",
        "colab_type": "text"
      },
      "source": [
        "### 10.7.1 文本情感分类数据"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "17Hz61XcdoyN",
        "colab_type": "text"
      },
      "source": [
        "#### 1 读取数据集"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rWhMzkmWd0vy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Make sure the directory the IMDb archive unpacks into exists.\n",
        "!mkdir -p ../data/aclImdb"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s_fgta3Nd9Ej",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install mxnet"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MrhF6_gUe3pz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from mxnet.gluon import utils as gutils"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zy9cWRNXe_pr",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "30b2279b-46ea-4675-c538-dbdbecafd826"
      },
      "source": [
        "def download_imdb(data_dir='../data'):\n",
        "    \"\"\"Fetch the aclImdb_v1 archive into `data_dir` and unpack it there.\n",
        "\n",
        "    The download goes through `gutils.download`, which is given the\n",
        "    expected SHA-1 of the archive.\n",
        "    \"\"\"\n",
        "    archive_sha1 = '01ada507287d82875905620988597833ad4e0903'\n",
        "    archive_url = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'\n",
        "    archive_path = gutils.download(archive_url, data_dir, sha1_hash=archive_sha1)\n",
        "    with tarfile.open(archive_path, 'r') as archive:\n",
        "        archive.extractall(data_dir)\n",
        "\n",
        "download_imdb()"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading ../data/aclImdb_v1.tar.gz from http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz...\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8qWVr4yVfEd7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def read_imdb(folder='train'):\n",
        "    \"\"\"Load one split of the extracted IMDb data as shuffled [review, label] pairs.\n",
        "\n",
        "    Reads every file under ../data/aclImdb/<folder>/pos and .../neg, strips\n",
        "    newline characters, lower-cases the text, and labels it 1 for 'pos' and\n",
        "    0 for 'neg'. The resulting list is shuffled in place before it is returned.\n",
        "    \"\"\"\n",
        "    samples = []\n",
        "    for sentiment, target in (('pos', 1), ('neg', 0)):\n",
        "        split_dir = os.path.join('../data/aclImdb/', folder, sentiment)\n",
        "        for fname in os.listdir(split_dir):\n",
        "            with open(os.path.join(split_dir, fname), 'rb') as handle:\n",
        "                raw_bytes = handle.read()\n",
        "            text = raw_bytes.decode('utf-8')\n",
        "            samples.append([text.replace('\\n', '').lower(), target])\n",
        "    random.shuffle(samples)\n",
        "    return samples\n",
        "\n",
        "train_data, test_data = read_imdb('train'), read_imdb('test')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FNy7_Y1JfLtB",
        "colab_type": "text"
      },
      "source": [
        "#### 2 预处理数据"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DKZA3ilSfYBx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_tokenized_imdb(data):\n",
        "    \"\"\"Split every review in `data` on single spaces and lower-case each token.\n",
        "\n",
        "    `data` is an iterable of (review, label) pairs; the labels are ignored.\n",
        "    Returns one list of tokens per review, in the same order.\n",
        "    \"\"\"\n",
        "    tokenized = []\n",
        "    for review, _label in data:\n",
        "        pieces = review.split(' ')\n",
        "        tokenized.append([piece.lower() for piece in pieces])\n",
        "    return tokenized"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kv_bcJzKftHR",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "ca343400-bc29-49fd-aed4-7b6aa3cdd392"
      },
      "source": [
        "def get_vocab_imdb(data):\n",
        "    \"\"\"Build a torchtext vocabulary from the tokens of `data`.\n",
        "\n",
        "    Token frequencies are counted over all tokenized reviews and passed to\n",
        "    `Vocab.Vocab` with min_freq=5.\n",
        "    \"\"\"\n",
        "    token_counts = collections.Counter()\n",
        "    for tokens in get_tokenized_imdb(data):\n",
        "        token_counts.update(tokens)\n",
        "    return Vocab.Vocab(token_counts, min_freq=5)\n",
        "\n",
        "vocab = get_vocab_imdb(train_data)\n",
        "'# words in vocab:', len(vocab)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "('# words in vocab:', 46151)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IkZKGV4zgSCy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def preprocess_imdb(data, vocab, max_l=500):\n",
        "    \"\"\"Turn (review, label) pairs into a fixed-length index tensor and a label tensor.\n",
        "\n",
        "    Args:\n",
        "        data: sequence of (review, label) pairs.\n",
        "        vocab: object whose `stoi` mapping converts a token to its index.\n",
        "        max_l: length every review is truncated or padded to (default 500).\n",
        "\n",
        "    Returns:\n",
        "        (features, labels): `features` has one row of `max_l` token indices\n",
        "        per review, `labels` holds the corresponding scores.\n",
        "    \"\"\"\n",
        "    def pad(x):\n",
        "        # Truncate long reviews; right-pad short ones with index 0.\n",
        "        return x[:max_l] if len(x) > max_l else x + [0] * (max_l - len(x))\n",
        "\n",
        "    tokenized_data = get_tokenized_imdb(data)\n",
        "    features = torch.tensor([pad([vocab.stoi[word] for word in words]) for words in tokenized_data])\n",
        "    labels = torch.tensor([score for _, score in data])\n",
        "    return features, labels"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8bknD2xrhf7t",
        "colab_type": "text"
      },
      "source": [
        "#### 3 创建数据迭代器"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_SLh2CcyhkoM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Encode both splits as tensors, then wrap them in mini-batch loaders.\n",
        "batch_size = 64\n",
        "train_features, train_labels = preprocess_imdb(train_data, vocab)\n",
        "test_features, test_labels = preprocess_imdb(test_data, vocab)\n",
        "train_set = Data.TensorDataset(train_features, train_labels)\n",
        "test_set = Data.TensorDataset(test_features, test_labels)\n",
        "# Only the training loader shuffles; the test loader keeps dataset order.\n",
        "train_iter = Data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
        "test_iter = Data.DataLoader(test_set, batch_size=batch_size)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qWzpKXfIiH65",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "6c8a7837-093f-49cd-fa88-ad7639ca2fa7"
      },
      "source": [
        "# Pull a single mini-batch to check tensor shapes, then report the batch count.\n",
        "first_batch = next(iter(train_iter), None)\n",
        "if first_batch is not None:\n",
        "    X, y = first_batch\n",
        "    print('X', X.shape, 'y', y.shape)\n",
        "'#batches:', len(train_iter)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "X torch.Size([64, 500]) y torch.Size([64])\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "('#batches:', 391)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZzDLkop9kMeN",
        "colab_type": "text"
      },
      "source": [
        "### 10.7.2 使用循环神经网络的模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3bD7c0DSiVJI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class BiRNN(nn.Module):\n",
        "    \"\"\"Bidirectional LSTM classifier over token-index sequences.\n",
        "\n",
        "    An embedding layer feeds a bidirectional LSTM; the LSTM outputs at the\n",
        "    first and last time steps are concatenated and mapped to 2 logits.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, vocab, embed_size, num_hiddens, num_layers):\n",
        "        super().__init__()\n",
        "        self.embedding = nn.Embedding(len(vocab), embed_size)\n",
        "        # bidirectional=True doubles the per-step output width to 2 * num_hiddens.\n",
        "        self.encoder = nn.LSTM(\n",
        "            input_size=embed_size,\n",
        "            hidden_size=num_hiddens,\n",
        "            num_layers=num_layers,\n",
        "            bidirectional=True,\n",
        "        )\n",
        "        # Two time steps are concatenated, hence 2 * (2 * num_hiddens) inputs.\n",
        "        self.decoder = nn.Linear(4 * num_hiddens, 2)\n",
        "\n",
        "    def forward(self, inputs):\n",
        "        \"\"\"Return logits of shape (batch, 2) for `inputs` of shape (batch, seq_len).\"\"\"\n",
        "        # The LSTM is not batch_first, so move the sequence axis to the front.\n",
        "        time_major = inputs.permute(1, 0)\n",
        "        embedded = self.embedding(time_major)\n",
        "        states, _ = self.encoder(embedded)\n",
        "        first_step, last_step = states[0], states[-1]\n",
        "        endpoints = torch.cat((first_step, last_step), -1)\n",
        "        return self.decoder(endpoints)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_YAt6Rrnjhyh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Embedding width, LSTM hidden size, and number of stacked LSTM layers.\n",
        "embed_size = 100\n",
        "num_hiddens = 100\n",
        "num_layers = 2\n",
        "net = BiRNN(vocab, embed_size, num_hiddens, num_layers)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_HQnvhdKjuQ4",
        "colab_type": "text"
      },
      "source": [
        "#### 1 加载预训练的词向量"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dNcTlW10kZiC",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "3d9a6ced-cb7e-4070-c8da-bc019bc69511"
      },
      "source": [
        "glove_vocab = Vocab.GloVe(name='6B', dim=100)"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            ".vector_cache/glove.6B.zip: 862MB [06:28, 2.22MB/s]                           \n",
            "100%|█████████▉| 398074/400000 [00:14<00:00, 27746.52it/s]"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lxcn6Cibkmk7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def load_pretrained_embedding(words, pretrained_vocab):\n",
        "    \"\"\"Build an embedding matrix for `words` from a pretrained vocabulary.\n",
        "\n",
        "    Args:\n",
        "        words: sequence of tokens; row i of the result belongs to words[i].\n",
        "        pretrained_vocab: object exposing `stoi` (token -> row index, raising\n",
        "            KeyError for unknown tokens) and `vectors` (2-D tensor of vectors).\n",
        "\n",
        "    Returns:\n",
        "        Float tensor of shape (len(words), vector_dim). Rows of tokens missing\n",
        "        from `pretrained_vocab` stay all-zero.\n",
        "    \"\"\"\n",
        "    embed = torch.zeros(len(words), pretrained_vocab.vectors[0].shape[0])\n",
        "    oov_count = 0  # number of out-of-vocabulary tokens\n",
        "    for i, word in enumerate(words):\n",
        "        try:\n",
        "            idx = pretrained_vocab.stoi[word]\n",
        "            embed[i, :] = pretrained_vocab.vectors[idx]\n",
        "        except KeyError:\n",
        "            # Count every token the pretrained vocabulary does not know.\n",
        "            oov_count += 1\n",
        "    if oov_count > 0:\n",
        "        # The count must be supplied to the '%d' placeholder explicitly.\n",
        "        print('There are %d oov words.' % oov_count)\n",
        "    return embed\n",
        "\n",
        "pretrained_weights = load_pretrained_embedding(vocab.itos, glove_vocab)\n",
        "net.embedding.weight.data.copy_(pretrained_weights)\n",
        "# Keep the copied vectors fixed: exclude the embedding from gradient updates.\n",
        "net.embedding.weight.requires_grad = False"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ybX80y2xm1MN",
        "colab_type": "text"
      },
      "source": [
        "#### 2 训练并评价模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gzcCYv7ImURb",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "7a9caf8f-9333-452c-f434-e8330dda2f65"
      },
      "source": [
        "lr = 0.01\n",
        "num_epochs = 5\n",
        "# Hand the optimizer only the parameters that are still trainable.\n",
        "trainable_params = [p for p in net.parameters() if p.requires_grad]\n",
        "optimizer = torch.optim.Adam(trainable_params, lr=lr)\n",
        "loss = nn.CrossEntropyLoss()\n",
        "d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "training on  cuda\n",
            "epoch 1, loss 0.5491, train acc 0.709, test acc 0.810, time 42.6 sec\n",
            "epoch 2, loss 0.1984, train acc 0.825, test acc 0.830, time 42.3 sec\n",
            "epoch 3, loss 0.1162, train acc 0.851, test acc 0.837, time 42.0 sec\n",
            "epoch 4, loss 0.0775, train acc 0.869, test acc 0.844, time 41.9 sec\n",
            "epoch 5, loss 0.0542, train acc 0.888, test acc 0.849, time 41.8 sec\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8-I8DgPUmu7X",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def predict_sentiment(net, vocab, sentence):\n",
        "    \"\"\"Classify a tokenized sentence as 'positive' or 'negative'.\n",
        "\n",
        "    Args:\n",
        "        net: model mapping a (1, seq_len) index tensor to (1, num_classes) scores.\n",
        "        vocab: object whose `stoi` mapping converts a token to its index.\n",
        "        sentence: list of tokens.\n",
        "\n",
        "    Returns:\n",
        "        'positive' if class 1 has the highest score, otherwise 'negative'.\n",
        "    \"\"\"\n",
        "    # Build the input on the same device as the model's parameters.\n",
        "    device = list(net.parameters())[0].device\n",
        "    sentence = torch.tensor([vocab.stoi[word] for word in sentence], device=device)\n",
        "    # Inference only: skip autograd bookkeeping for the forward pass.\n",
        "    with torch.no_grad():\n",
        "        label = torch.argmax(net(sentence.view((1, -1))), dim=1)\n",
        "    return 'positive' if label.item() == 1 else 'negative'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4zE_HvEdnval",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "4b3936c5-0c5e-48a5-d06e-e4eb64f8e1be"
      },
      "source": [
        "predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'great'])"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'positive'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2qVjyQs6n7iu",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "c0439194-19f3-4411-a0fc-fbe58f6d8738"
      },
      "source": [
        "predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'bad'])"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'negative'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 21
        }
      ]
    }
  ]
}